In [ ]:
%matplotlib widget

import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseButton, MouseEvent
import geoengine as ge
from geoengine.workflow_builder.operators import TemporalRasterAggregation, RasterStacker, RenameBands, \
    RasterTypeConversion, Expression, GdalSource, RasterVectorJoin, OgrSource, Onnx
import numpy as np
from typing import Tuple, Callable, Optional
import geopandas as gpd
import pandas as pd
from shapely.geometry import Point
import xarray as xr
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx # TODO: in ge library
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType # TODO: better import
In [ ]:
def points_by_drawing(filename: str,
                      color: str,
                      crs: str,
                      background: Callable[[plt.Axes], None],
                      figsize: Optional[Tuple[int, int]] = None) -> None:

    try:
        points = gpd.read_file(filename)
        points.set_crs(crs, inplace=True, allow_override=True)
    except:
        points = gpd.GeoDataFrame(crs=crs, geometry=[])

    fig, ax = plt.subplots(figsize=figsize)

    background(ax)

    def on_click(event: MouseEvent):
        if event.button is not MouseButton.LEFT:
            return
        if not event.inaxes:
            return
        
        points.loc[len(points)] = [Point(event.xdata, event.ydata)]

        points.plot(ax=ax, c=color)

        points.to_file(filename)

    fig.canvas.mpl_connect('button_press_event', on_click)

    points.plot(ax=ax, c=color)
In [ ]:
koeln_center_utm = [356766, 5644819]
In [ ]:
ge.initialize("http://localhost:3030/api")
In [ ]:
def _query_rectangle(*,
                     center: Tuple[float, float],
                     time: np.datetime64,
                     radius_px: float = 512) -> ge.QueryRectangle:
    resolution = ge.SpatialResolution(10, 10)
    bbox = ge.BoundingBox2D(
        xmin=center[0] - resolution.x_resolution * radius_px,
        xmax=center[0] + resolution.x_resolution * radius_px,
        ymin=center[1] - resolution.y_resolution * radius_px,
        ymax=center[1] + resolution.y_resolution * radius_px,
    )
    return ge.QueryRectangle(
        spatial_bounds=bbox,
        time_interval=ge.TimeInterval(time),
        srs='EPSG:32632',
        resolution=resolution,
    )

red_band = GdalSource("_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B04`")
green_band = GdalSource("_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B03`")
blue_band = GdalSource("_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B02`")
nir_band = GdalSource("_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:B08`")
mask_band = RasterTypeConversion(
    GdalSource("_:5779494c-f3a2-48b3-8a2d-5fbba8c5b6c5:`UTM32N:SCL`"),
    output_data_type='U16',
)

workflow = TemporalRasterAggregation(
    aggregation_type='mean',
    granularity='months',
    window_size=1,
    ignore_no_data=True,
    source=RasterStacker(
        sources=[
            Expression(
                expression="if (B == 3 || (B >= 7 && B <= 11)) { NODATA } else { A }",
                output_type="F32",
                source=RasterStacker([red_band, mask_band]),
            ),
            Expression(
                expression="if (B == 3 || (B >= 7 && B <= 11)) { NODATA } else { A }",
                output_type="F32",
                source=RasterStacker([green_band, mask_band]),
            ),
            Expression(
                expression="if (B == 3 || (B >= 7 && B <= 11)) { NODATA } else { A }",
                output_type="F32",
                source=RasterStacker([blue_band, mask_band]),
            ),
            Expression(
                expression="if (C == 3 || (C >= 7 && C <= 11)) { NODATA } else { (A - B) / (A + B) }",
                output_type="F32",
                source=RasterStacker([nir_band, red_band, mask_band]),
            ),
            # RasterTypeConversion(
            #     ge.workflow_builder.blueprints.sentinel2_cloud_free_band("B04"),
            #     output_data_type="F32",
            # ),
            # RasterTypeConversion(
            #     ge.workflow_builder.blueprints.sentinel2_cloud_free_band("B03"),
            #     output_data_type="F32",
            # ),
            # RasterTypeConversion(
            #     ge.workflow_builder.blueprints.sentinel2_cloud_free_band("B02"),
            #     output_data_type="F32",
            # ),
            # ge.workflow_builder.blueprints.sentinel2_cloud_free_ndvi(),
        ],
        rename=RenameBands.rename(['red', 'green', 'blue', 'ndvi']),
    )
)

workflow = ge.register_workflow(workflow)

query_rectangle = _query_rectangle(
    center=koeln_center_utm,
    time=np.datetime64("2022-07-01T00:00:00"),
    radius_px=512,
)

data_array = await workflow.raster_stream_into_xarray(
    query_rectangle=query_rectangle,
    clip_to_query_rectangle=True,
    bands=[0,1,2,3], # TODO: improve for user, default = all? where are the band names?
)
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/rasterio/windows.py:314: RasterioDeprecationWarning: The height, width, and precision parameters are unused, deprecated, and will be removed in 2.0.0.
  warnings.warn(
In [ ]:
points_by_drawing(
    "water.geojson",
    color="blue",
    crs=str(data_array.rio.crs),
    background=lambda ax: data_array.isel(time=0, band=[0,1,2]).plot.imshow(
        rgb="band",
        vmax=4000,
        ax=ax,
    ),
    figsize=(10,10),
)
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/matplotlib/cm.py:478: RuntimeWarning: invalid value encountered in cast
  xx = (xx * 255).astype(np.uint8)
Figure
No description has been provided for this image
In [ ]:
points_by_drawing(
    "non_water.geojson",
    color="green",
    crs=str(data_array.rio.crs),
    background=lambda ax: data_array.isel(time=0, band=[0,1,2]).plot.imshow(
        rgb="band",
        vmax=4000,
        ax=ax,
    ),
    figsize=(10,10),
)
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/matplotlib/cm.py:478: RuntimeWarning: invalid value encountered in cast
  xx = (xx * 255).astype(np.uint8)
Figure
No description has been provided for this image
In [ ]:
labels = pd.concat(
    [
        gpd.read_file("water.geojson").assign(water=1),
        gpd.read_file("non_water.geojson").assign(water=0),
    ],
    ignore_index=True,
).set_crs(
    str(data_array.rio.crs),
    allow_override=True,
# ).assign(
#     start="2022-07-01",
#     end="2022-08-01"
)

labels_name = ge.upload_dataframe(
    labels,
    # time=ge.datasets.OgrSourceDatasetTimeType.start_end(
    #     start_field="start",
    #     start_format=ge.datasets.OgrSourceTimeFormat.auto(),
    #     end_field="end",
    #     end_format=ge.datasets.OgrSourceTimeFormat.auto(),
    # )
)

labels_name
Out[ ]:
ddeeddd5-a45d-40cc-b692-a005f93643df:82b7a3b3-8b88-45a7-b21b-295549e7ff9f
In [ ]:
training_workflow = RasterVectorJoin(
    raster_sources=[workflow.workflow_definition().operator],
    vector_source=OgrSource(labels_name),
    names=ge.workflow_builder.operators.ColumnNames.default(),
    temporal_aggregation="none",
    feature_aggregation="first",
    temporal_aggregation_ignore_nodata=True,
    feature_aggregation_ignore_nodata=True,
)

training_workflow = ge.register_workflow(training_workflow)

training_df = training_workflow.get_dataframe(query_rectangle)

training_df.head()
Out[ ]:
geometry blue green ndvi red water start end
0 POINT (359492.047 5649059.959) 458.799988 485.799988 -0.158283 328.399994 1 2022-07-01 00:00:00+00:00 2022-08-01 00:00:00+00:00
1 POINT (359436.582 5648255.711) 395.500000 439.250000 -0.203166 279.000000 1 2022-07-01 00:00:00+00:00 2022-08-01 00:00:00+00:00
2 POINT (361045.078 5639935.905) 502.500000 561.666687 -0.239277 407.500000 1 2022-07-01 00:00:00+00:00 2022-08-01 00:00:00+00:00
3 POINT (359381.116 5640102.302) 416.000000 478.399994 -0.160315 337.200012 1 2022-07-01 00:00:00+00:00 2022-08-01 00:00:00+00:00
4 POINT (357966.750 5641599.867) 412.500000 486.000000 -0.079246 332.666656 1 2022-07-01 00:00:00+00:00 2022-08-01 00:00:00+00:00
In [ ]:
training_workflow.get_result_descriptor()
Out[ ]:
Data type:         MultiPoint
Spatial Reference: EPSG:32632
Columns:
  red:
    Column Type: float
    Measurement: unitless
  green:
    Column Type: float
    Measurement: unitless
  ndvi:
    Column Type: float
    Measurement: unitless
  blue:
    Column Type: float
    Measurement: unitless
  water:
    Column Type: int
    Measurement: unitless
In [ ]:
fig, ax = plt.subplots()

data_array.isel(time=0, band=[0,1,2]).plot.imshow(
    rgb="band",
    vmax=4000,
    ax=ax,
)

training_df.plot(ax=ax, column="water", k=2, colormap="winter_r")
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/geopandas/plotting.py:644: FutureWarning: 'colormap' is deprecated, please use 'cmap' instead (for consistency with matplotlib)
  warnings.warn(
Out[ ]:
<Axes: title={'center': 'time = 2022-07-01, spatial_ref = 0'}, xlabel='x coordinate of projection\n[metre]', ylabel='y coordinate of projection\n[metre]'>
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/matplotlib/cm.py:478: RuntimeWarning: invalid value encountered in cast
  xx = (xx * 255).astype(np.uint8)
Figure
No description has been provided for this image
In [ ]:
X = training_df[["red", "green", "blue", "ndvi"]].to_numpy().astype(np.float32)
print(X[:5])
[[ 3.2839999e+02  4.8579999e+02  4.5879999e+02 -1.5828317e-01]
 [ 2.7900000e+02  4.3925000e+02  3.9550000e+02 -2.0316632e-01]
 [ 4.0750000e+02  5.6166669e+02  5.0250000e+02 -2.3927733e-01]
 [ 3.3720001e+02  4.7839999e+02  4.1600000e+02 -1.6031471e-01]
 [ 3.3266666e+02  4.8600000e+02  4.1250000e+02 -7.9246230e-02]]
In [ ]:
y = training_df["water"].to_numpy().astype(np.uint8)
print(y[:5])
[1 1 1 1 1]
In [ ]:
clf = RandomForestClassifier(random_state=42)

clf.fit(X, y)
Out[ ]:
RandomForestClassifier(random_state=42)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomForestClassifier(random_state=42)
In [ ]:
onx = to_onnx(clf, X[:1], target_opset=9) # target_opset is the ONNX version to use
In [ ]:
# print(onx)

from onnx.checker import check_model

check_model(onx)
In [ ]:
model_name = f"{ge.get_session().user_id}:rf2"

metadata = MlModelMetadata(
    file_name="model.onnx", # always this?
    input_type=RasterDataType.F32,
    num_input_bands=4, # should know this
    output_type=RasterDataType.I64,
)

model_config = ge.ml.MlModelConfig(
    name=model_name,
    metadata=metadata,
    display_name="Decision Tree",
    description="A simple decision tree model",
)

ge.register_ml_model(onnx_model=onx, model_config=model_config)
In [ ]:
model_workflow = Onnx(
    source=workflow.workflow_definition().operator,
    model=model_name,
)

model_workflow = ge.register_workflow(model_workflow)

model_workflow.get_result_descriptor()
Out[ ]:
Data type:         I64
Spatial Reference: EPSG:32632
Bands:
    prediction: unitless
In [ ]:
classified_array = await model_workflow.raster_stream_into_xarray(
    query_rectangle=query_rectangle,
    clip_to_query_rectangle=True,
    bands=[0], # TODO: improve for user, default = all? where are the band names?
)

classified_array
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/rasterio/windows.py:314: RasterioDeprecationWarning: The height, width, and precision parameters are unused, deprecated, and will be removed in 2.0.0.
  warnings.warn(
Out[ ]:
<xarray.DataArray (time: 1, band: 1, y: 1025, x: 1025)>
array([[[[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]]]])
Coordinates:
  * x            (x) float64 3.516e+05 3.517e+05 ... 3.619e+05 3.619e+05
  * y            (y) float64 5.65e+06 5.65e+06 5.65e+06 ... 5.64e+06 5.64e+06
  * time         (time) datetime64[ns] 2022-07-01
  * band         (band) int64 0
    spatial_ref  int64 0
xarray.DataArray
  • time: 1
  • band: 1
  • y: 1025
  • x: 1025
  • 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    array([[[[0, 0, 0, ..., 0, 0, 0],
             [0, 0, 0, ..., 0, 0, 0],
             [0, 0, 0, ..., 0, 0, 0],
             ...,
             [0, 0, 0, ..., 0, 0, 0],
             [0, 0, 0, ..., 0, 0, 0],
             [0, 0, 0, ..., 0, 0, 0]]]])
    • x
      (x)
      float64
      3.516e+05 3.517e+05 ... 3.619e+05
      axis :
      X
      long_name :
      x coordinate of projection
      standard_name :
      projection_x_coordinate
      units :
      metre
      array([351645., 351655., 351665., ..., 361865., 361875., 361885.])
    • y
      (y)
      float64
      5.65e+06 5.65e+06 ... 5.64e+06
      axis :
      Y
      long_name :
      y coordinate of projection
      standard_name :
      projection_y_coordinate
      units :
      metre
      array([5649935., 5649925., 5649915., ..., 5639715., 5639705., 5639695.])
    • time
      (time)
      datetime64[ns]
      2022-07-01
      array(['2022-07-01T00:00:00.000000000'], dtype='datetime64[ns]')
    • band
      (band)
      int64
      0
      array([0])
    • spatial_ref
      ()
      int64
      0
      crs_wkt :
      PROJCS["WGS 84 / UTM zone 32N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",9],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32632"]]
      semi_major_axis :
      6378137.0
      semi_minor_axis :
      6356752.314245179
      inverse_flattening :
      298.257223563
      reference_ellipsoid_name :
      WGS 84
      longitude_of_prime_meridian :
      0.0
      prime_meridian_name :
      Greenwich
      geographic_crs_name :
      WGS 84
      horizontal_datum_name :
      World Geodetic System 1984
      projected_crs_name :
      WGS 84 / UTM zone 32N
      grid_mapping_name :
      transverse_mercator
      latitude_of_projection_origin :
      0.0
      longitude_of_central_meridian :
      9.0
      false_easting :
      500000.0
      false_northing :
      0.0
      scale_factor_at_central_meridian :
      0.9996
      spatial_ref :
      PROJCS["WGS 84 / UTM zone 32N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",9],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32632"]]
      GeoTransform :
      351640.0 10.0 0.0 5649940.0 0.0 -10.0
      array(0)
    • x
      PandasIndex
      PandasIndex(Float64Index([351645.0, 351655.0, 351665.0, 351675.0, 351685.0, 351695.0,
                    351705.0, 351715.0, 351725.0, 351735.0,
                    ...
                    361795.0, 361805.0, 361815.0, 361825.0, 361835.0, 361845.0,
                    361855.0, 361865.0, 361875.0, 361885.0],
                   dtype='float64', name='x', length=1025))
    • y
      PandasIndex
      PandasIndex(Float64Index([5649935.0, 5649925.0, 5649915.0, 5649905.0, 5649895.0, 5649885.0,
                    5649875.0, 5649865.0, 5649855.0, 5649845.0,
                    ...
                    5639785.0, 5639775.0, 5639765.0, 5639755.0, 5639745.0, 5639735.0,
                    5639725.0, 5639715.0, 5639705.0, 5639695.0],
                   dtype='float64', name='y', length=1025))
    • time
      PandasIndex
      PandasIndex(DatetimeIndex(['2022-07-01'], dtype='datetime64[ns]', name='time', freq=None))
    • band
      PandasIndex
      PandasIndex(Int64Index([0], dtype='int64', name='band'))
In [ ]:
fig, ax = plt.subplots()

data_array.isel(time=0, band=[0,1,2]).plot.imshow(
    rgb="band",
    vmax=4000,
    ax=ax,
)

classified_array.isel(time=0, band=0).plot.imshow(ax=ax, levels=3, colors=["#00000000", "blue"], vmax=2, add_colorbar=False)

plt.show();
/home/beilschmidt/git/geoengine-python/env/lib/python3.10/site-packages/matplotlib/cm.py:478: RuntimeWarning: invalid value encountered in cast
  xx = (xx * 255).astype(np.uint8)
Figure
No description has been provided for this image